637fba
@@ -18,10 +18,13 @@
 
 package org.apache.hadoop.hive.ql.optimizer.spark;
 
+import java.util.Collection;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Set;
 import java.util.Stack;
 
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.ObjectPair;
@@ -50,6 +53,8 @@
 import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
 import org.apache.hadoop.hive.ql.stats.StatsUtils;
 
+import static org.apache.hadoop.hive.ql.plan.ReduceSinkDesc.ReducerTraits.UNIFORM;
+
 /**
  * SetSparkReducerParallelism determines how many reducers should
  * be run for a given reduce sink, clone from SetReducerParallelism.
@@ -120,41 +125,64 @@
public Object process(Node nd, Stack<Node> stack,
           }
         }
 
-        long numberOfBytes = 0;
-
-        if (useOpStats) {
-          // we need to add up all the estimates from the siblings of this reduce sink
-          for (Operator<? extends OperatorDesc> sibling
-              : sink.getChildOperators().get(0).getParentOperators()) {
-            if (sibling.getStatistics() != null) {
-              numberOfBytes = StatsUtils.safeAdd(numberOfBytes, sibling.getStatistics().getDataSize());
-              if (LOG.isDebugEnabled()) {
-                LOG.debug("Sibling " + sibling + " has stats: " + sibling.getStatistics());
-              }
-            } else {
-              LOG.warn("No stats available from: " + sibling);
-            }
-          }
-        } else if (parentSinks.isEmpty()) {
-          // Not using OP stats and this is the first sink in the path, meaning that
-          // we should use TS stats to infer parallelism
-          for (Operator<? extends OperatorDesc> sibling
-              : sink.getChildOperators().get(0).getParentOperators()) {
-            Set<TableScanOperator> sources =
-                OperatorUtils.findOperatorsUpstream(sibling, TableScanOperator.class);
-            for (TableScanOperator source : sources) {
-              if (source.getStatistics() != null) {
-                numberOfBytes = StatsUtils.safeAdd(numberOfBytes, source.getStatistics().getDataSize());
+        if (useOpStats || parentSinks.isEmpty()) {
+          long numberOfBytes = 0;
+          if (useOpStats) {
+            // we need to add up all the estimates from the siblings of this reduce sink
+            for (Operator<? extends OperatorDesc> sibling
+                : sink.getChildOperators().get(0).getParentOperators()) {
+              if (sibling.getStatistics() != null) {
+                numberOfBytes = StatsUtils.safeAdd(numberOfBytes, sibling.getStatistics().getDataSize());
                 if (LOG.isDebugEnabled()) {
-                  LOG.debug("Table source " + source + " has stats: " + source.getStatistics());
+                  LOG.debug("Sibling " + sibling + " has stats: " + sibling.getStatistics());
                 }
               } else {
-                LOG.warn("No stats available from table source: " + source);
+                LOG.warn("No stats available from: " + sibling);
               }
             }
+          } else {
+            // Not using OP stats and this is the first sink in the path, meaning that
+            // we should use TS stats to infer parallelism
+            for (Operator<? extends OperatorDesc> sibling
+                : sink.getChildOperators().get(0).getParentOperators()) {
+              Set<TableScanOperator> sources =
+                  OperatorUtils.findOperatorsUpstream(sibling, TableScanOperator.class);
+              for (TableScanOperator source : sources) {
+                if (source.getStatistics() != null) {
+                  numberOfBytes = StatsUtils.safeAdd(numberOfBytes, source.getStatistics().getDataSize());
+                  if (LOG.isDebugEnabled()) {
+                    LOG.debug("Table source " + source + " has stats: " + source.getStatistics());
+                  }
+                } else {
+                  LOG.warn("No stats available from table source: " + source);
+                }
+              }
+            }
+            LOG.debug("Gathered stats for sink " + sink + ". Total size is "
+                + numberOfBytes + " bytes.");
+          }
+
+          // Divide it by 2 so that we can have more reducers
+          long bytesPerReducer = context.getConf().getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2;
+          int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer,
+              maxReducers, false);
+
+          getSparkMemoryAndCores(context);
+          if (sparkMemoryAndCores != null &&
+              sparkMemoryAndCores.getFirst() > 0 && sparkMemoryAndCores.getSecond() > 0) {
+            // warn the user if bytes per reducer is much larger than memory per task
+            if ((double) sparkMemoryAndCores.getFirst() / bytesPerReducer < 0.5) {
+              LOG.warn("Average load of a reducer is much larger than its available memory. " +
+                  "Consider decreasing hive.exec.reducers.bytes.per.reducer");
+            }
+
+            // If there are more cores, use the number of cores
+            numReducers = Math.max(numReducers, sparkMemoryAndCores.getSecond());
           }
-          LOG.debug("Gathered stats for sink " + sink + ". Total size is "
-              + numberOfBytes + " bytes.");
+          numReducers = Math.min(numReducers, maxReducers);
+          LOG.info("Set parallelism for reduce sink " + sink + " to: " + numReducers +
+              " (calculated)");
+          desc.setNumReducers(numReducers);
         } else {
           // Use the maximum parallelism from all parent reduce sinks
           int numberOfReducers = 0;
@@ -164,30 +192,14 @@
public Object process(Node nd, Stack<Node> stack,
           desc.setNumReducers(numberOfReducers);
           LOG.debug("Set parallelism for sink " + sink + " to " + numberOfReducers
               + " based on its parents");
-          return false;
         }
-
-        // Divide it by 2 so that we can have more reducers
-        long bytesPerReducer = context.getConf().getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2;
-        int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer,
-            maxReducers, false);
-
-        getSparkMemoryAndCores(context);
-        if (sparkMemoryAndCores != null &&
-            sparkMemoryAndCores.getFirst() > 0 && sparkMemoryAndCores.getSecond() > 0) {
-          // warn the user if bytes per reducer is much larger than memory per task
-          if ((double) sparkMemoryAndCores.getFirst() / bytesPerReducer < 0.5) {
-            LOG.warn("Average load of a reducer is much larger than its available memory. " +
-                "Consider decreasing hive.exec.reducers.bytes.per.reducer");
-          }
-
-          // If there are more cores, use the number of cores
-          numReducers = Math.max(numReducers, sparkMemoryAndCores.getSecond());
+        final Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> keyCols =
+            ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getKeyCols());
+        final Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> partCols =
+            ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getPartitionCols());
+        if (keyCols != null && keyCols.equals(partCols)) {
+          desc.setReducerTraits(EnumSet.of(UNIFORM));
         }
-        numReducers = Math.min(numReducers, maxReducers);
-        LOG.info("Set parallelism for reduce sink " + sink + " to: " + numReducers +
-            " (calculated)");
-        desc.setNumReducers(numReducers);
       }
     } else {
       LOG.info("Number of reducers for sink " + sink + " was already determined to be: " + desc.getNumReducers());
